Table of Contents
Welcome to SemiPy's documentation
Warning
SemiPy is not released yet. Stay tuned, the repository should be public soon at https://github.com/SemiPy.
A PyTorch based Python library dedicated to Semi-Supervised Learning (SSL).
Introduction
Welcome to the documentation of SemiPy, an open-source Python library designed specifically for Semi-Supervised Learning (SSL) using the power of PyTorch. SemiPy aims to be a toolbox for tackling SSL experiments and real-world problems. It includes different SSL methods, datasets and useful tools for SSL. With SemiPy, we hope that you will be able to unlock the potential of SSL in your machine learning projects.
See Getting Started for instructions on how to use the library and diverse tutorials.
See About for more information about the team behind SemiPy and the license used.
What is Semi-Supervised Learning?
Semi-Supervised Learning (SSL) is a field of machine learning that addresses the challenges posed by limited labelled data by leveraging the potential hidden information of unlabelled data. SSL has emerged as a powerful machine learning strategy, bridging the gap between the efficiency of unsupervised learning and the performance of supervised learning. By using the unlabelled data, the model can learn from the underlying structure of the data distribution, even when explicit labels are unavailable.
About SemiPy
Contributors
- Lucas Boiteau
- Pierre-Alexandre Mattei
- Quentin Oliveau
- Alexandre Gensse
- Aude Sportisse
- Hugo Schmutz
License
BSD 3-Clause License
Getting started
How to use SemiPy?
SemiPy has the advantage of being versatile: there are several ways to use the library.
Configuration file
A first, simple way to use SemiPy is to write a configuration file describing all the parameters the user can change. Here you can find the default configuration file used by the library. Below is a list of all parameters and their descriptions.
| Parameter | Description | Default value |
|---|---|---|
| USE_LIGHTNING | To use PyTorch Lightning or not | True |
| EPOCHS | Total number of epochs. One epoch is finished when the model has seen all labelled items. | 1 |
| BALANCING_WEIGHT | Corresponds to the 'lambda' parameter found in various SSL papers. Weights the unlabelled loss relative to the labelled one. | 0.5 |
| DEBIASED | To enable safe SSL via the debiased loss (Schmutz et al.). | False |
| SELECTION_THRESHOLD | Probability threshold for pseudo-labels. | 0.95 |
| BATCH_SIZE | Size of each batch. | 64 |
| LABELLED_PROPORTION | Proportion of labelled items relative to unlabelled ones. Needed for JointSampler. | 0.5 |
| SAVE_PATH | Path to save models at the end of training and best model during training. | './saves' |
| OPTIMIZER | Optimizer name and parameters | NAME: 'SGD'; PARAMS: {lr: 1.0e-3, momentum: 0.9} |
| SCHEDULER | Learning rate scheduler | null |
| NET | Model to train. | 'resnet18' |
| METHOD | SSL method to use. | 'pseudolabel' |
| NUM_WARMUP_EPOCHS | Number of warmup epochs. Used for PiModel. | null |
| DATA | Data information | See default values in DATA details below |
| USE_MULTIGPU | To enable multi-GPU training. | False |
| NUM_GPU | Number of used GPU(s). | null |
| MULTIGPU_STRATEGY | Strategy for multi-GPU training. For now only 'ddp' is supported. | null |
| EMA | Exponential Moving Average coefficient for EMA on model's parameters. | null |
| METRICS | Metrics information | See default values in METRICS details below |
| EARLYSTOPPING | EarlyStopping information | See default values in EARLYSTOPPING details below |
DATA parameters details
| Parameter | Description | Default value |
|---|---|---|
| NAME | Name of the desired dataset. Use 'custom' for your own dataset. | null |
| VALIDATION_PROPORTION | Size proportion of validation set compared to labelled items. | null |
| TEST_PROPORTION | (In case current dataset does not have a test set) Size proportion for test set compared to whole dataset. | null |
| LABELLED_SAMPLES | Number of labelled samples in training set. | null |
| UNLABELLED_SAMPLES | Number of unlabelled samples in training set. | null |
| INCLUDE_LABELLED | Whether to include labelled items (without their labels) in the unlabelled set, to add information. | True |
| USE_EXTRA | Used by SVHN dataset | False |
| DATA/SPLITS | Subsection for defining different data splits. | See below for parameters |
| SPLITS/TRAIN | Subsection example for Train split. | |
| SPLITS/PATH | Path to train set. | 'data' |
| SPLITS/NAME_UNLABELLED | When using your own dataset: name of the folder containing unlabelled items. | 'nodata' |
| SPLITS/TRANSFORMS | To add transformations to the dataset. | [] |
METRICS parameters details
| Parameter | Description | Default value |
|---|---|---|
| METRICS/VALIDATION | Subsection for validation metrics. | See below for parameters. |
| VALIDATION/NAME | Name of the first validation metric. | Accuracy |
| VALIDATION/PARAMS | Parameters for the corresponding metric. | {task: multiclass} |
| METRICS/TEST | Subsection for test metrics. | See below for parameters. |
| TEST/NAME | Name of the first test metric. | Accuracy |
| TEST/PARAMS | Parameters for the corresponding metric. | {task: multiclass} |
Tip
You can add as many metrics as you want: simply add a new item to the list. The metric names should be taken from the torchmetrics list of metrics.
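For example, a METRICS section with two validation metrics might look like the following (the 'F1Score' name comes from torchmetrics; the exact YAML layout is an assumption, refer to the default configuration file):

```yaml
METRICS:
  VALIDATION:
    - NAME: 'Accuracy'
      PARAMS: {task: multiclass}
    - NAME: 'F1Score'
      PARAMS: {task: multiclass}
```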
EARLYSTOPPING parameters details
| Parameter | Description | Default value |
|---|---|---|
| EARLYSTOPPING/NAME | Name of the monitored metric. | VALIDATION/Loss |
| EARLYSTOPPING/PARAMS | Dictionary of parameters for EarlyStopping. | {'mode': 'min', 'patience': 10} |
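Assembled from the defaults listed in the tables above, a minimal configuration file could look like this sketch (the exact YAML layout is an assumption; the default file shipped with the library remains the reference):

```yaml
USE_LIGHTNING: True
EPOCHS: 1
BALANCING_WEIGHT: 0.5
DEBIASED: False
SELECTION_THRESHOLD: 0.95
BATCH_SIZE: 64
LABELLED_PROPORTION: 0.5
SAVE_PATH: './saves'
OPTIMIZER:
  NAME: 'SGD'
  PARAMS: {lr: 1.0e-3, momentum: 0.9}
SCHEDULER: null
NET: 'resnet18'
METHOD: 'pseudolabel'
DATA:
  NAME: null
  SPLITS:
    TRAIN:
      PATH: 'data'
      TRANSFORMS: []
```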
Once you have your configuration file ready, you have multiple choices: use a notebook or use a script.
Notebook Usage
If you want to use a notebook, simply feed the configuration file's path to your trainer.
from pytorch_lightning import Trainer
import semipy as smp
trainer = Trainer(max_epochs=100, accelerator='gpu')
lightning_module = smp.pl.LitFixMatch(config='config.yaml')
trainer.fit(lightning_module)
Script usage
You can also use a custom script, or the main.py script present at the root of the library. Just pass your configuration file with the -config "path_to_config_file" option.
Without a configuration file?
It is still possible to use SemiPy without a YAML configuration file: you can instead pass a dictionary of parameters. Note that when building a trainer (with or without PyTorch Lightning), the library will still use the default configuration file mentioned above to fill in any parameters not specified by the user. For example, with a dictionary:
import semipy as smp
args = {'EPOCHS': 100, 'BALANCING_WEIGHT': 0.12}
trainer = smp.tools.SSLTrainer(config=args)
trainer.fit()
Finally, if you don't want to use the provided trainers, you have access to all the useful functions that come with SemiPy, especially the loss functions specific to each Semi-Supervised Learning algorithm (written in a PyTorch style), and of course the JointSampler. It's up to you to build your own code with these lower-level functions. Note that most of the code is PyTorch Lightning compatible, so you can also build your own Lightning Module.
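To illustrate what such a loss function computes, the core of a pseudo-labelling loss can be written in a few lines of plain PyTorch (an illustrative sketch, not SemiPy's exact API; the 0.95 threshold mirrors the SELECTION_THRESHOLD parameter above):

```python
import torch
import torch.nn.functional as F

def pseudo_label_loss(logits_unlabelled, threshold=0.95):
    """Cross-entropy on confident pseudo-labels (illustrative sketch)."""
    # Pseudo-labels are taken from the model's own predictions, without gradients.
    probs = torch.softmax(logits_unlabelled.detach(), dim=1)
    confidence, pseudo_labels = probs.max(dim=1)
    # Only predictions above the probability threshold contribute to the loss.
    mask = (confidence >= threshold).float()
    loss = F.cross_entropy(logits_unlabelled, pseudo_labels, reduction='none')
    return (loss * mask).mean()

# Very confident logits yield a (small) positive loss; uncertain ones are masked out.
confident = torch.tensor([[10.0, -10.0], [-10.0, 10.0]])
uncertain = torch.tensor([[0.1, 0.0], [0.0, 0.1]])
print(pseudo_label_loss(confident).item())  # small positive value
print(pseudo_label_loss(uncertain).item())  # 0.0, every prediction is below threshold
```

In a full training step, this term would be scaled by BALANCING_WEIGHT and added to the supervised loss on the labelled part of the batch.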
MedMNIST Tutorial
In this tutorial, we will run a simple training with FixMatch on 'pathmnist', a 9-class dataset from MedMNIST composed of 89,996 training samples. We will use only 6,000 samples as labelled; the rest will be unlabelled. The dataset also comes with a validation set of 10,004 images and a test set of 7,180 images.
import semipy as smp
import torch
To simplify things, we will use a configuration file that you can find here. That way, it is easier to define parameters for our training.
args = smp.tools.get_config('config.yaml')
Then, after reading our config file with 'get_config', we retrieve our datasets. Note that all the steps in this tutorial can be done automatically by the "SSLTrainer" class defined in SemiPy, but as this is a tutorial, it is better to detail everything. Hence, in the next cell, we manually pass parameters such as the number of labelled samples to 'get_medmnist', even though they are present in the configuration file. Also note that 'augmentation' is set to True because FixMatch training requires strong augmentations.
sets = smp.datasets.get_medmnist(name='pathmnist', num_labelled=6000, augmentation=True, include_labelled=True)
C:\Users\lboiteau\Documents\Demos\semipy\datasets\medmnist.py:44: UserWarning: Warning: valid_proportion is set to 0 or not defined. Length of validation set will be the length of the original set from MedMNIST
warnings.warn('Warning: valid_proportion is set to 0 or not defined. '
Using downloaded and verified file: C:\Users\lboiteau\.medmnist\pathmnist.npz
Next, let's get a model to train. We will choose a simple resnet18 from PyTorch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = smp.models.get_model('resnet18', num_classes=9)
model = model.to(device)
We choose an SGD optimizer with a learning rate of 0.03.
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
Next is a very important part: the creation of the dataloader. In a usual training process, we would just use a simple dataloader. Here, we need to separate labelled and unlabelled samples. That's why we will use the JointSampler from SemiPy. It allows you to use a single dataset composed of both labelled and unlabelled items. Simply choose a batch size and the proportion of labelled items you want in each batch, and you are good to go!
sampler = smp.sampler.JointSampler(dataset=sets['Train'], batch_size=64, proportion=0.5)
dataloader = torch.utils.data.DataLoader(sets['Train'], batch_sampler=sampler)
val_dataloader = torch.utils.data.DataLoader(sets['Validation'], batch_size=64, shuffle=False)
Finally, it's time to choose an SSL method for training. We will choose FixMatch, and as we are not using PyTorch Lightning in this tutorial, we will use the simple trainer included in SemiPy:
trainer = smp.methods.FixMatch(args, model, dataloader, val_dataloader, optimizer, scheduler=None, num_classes=9)
trainer.training()
(training progress bars: 25 epochs of 188 iterations each; output truncated)
[1.6222473915587081, 1.1791775106115545, 1.0396045560532428, 0.9519544747915674, 0.8916148728829749, 0.8484063008998303, 0.8072013575980004, 0.7802366132431842, 0.7285392368410496, 0.7350093584428442, 0.6964399213803575, 0.6474804099886975, 0.6309603810944455, 0.6154432499662359, 0.6662413840915294, 0.5623518959321874, 0.5926185534038442, 0.5563194929285252, 0.5360793276353085, 0.5329677575921759, 0.5152371287187363, 0.5018856583282034, 0.47833075889564575, 0.47020974707730273, 0.44847835243699397]
trainer.eval(0, 0)
{'VALIDATION/MulticlassAccuracy': tensor(0.2592),
'VALIDATION/Loss': tensor(1.9023)}
Two circles demo
This demo uses the two circles dataset from Scikit-learn. It showcases some SSL methods on simple data that can be understood visually. These experiments are made with SemiPy.
from sklearn.datasets import make_circles, make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib.colors as cl
import seaborn as sns
import numpy as np
import math
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
import torch.nn.functional as F
import random
import semipy
Data acquisition
X, y = make_circles(n_samples=2000, noise=0.07)
X_train, X_val, y_train, y_val = train_test_split(X, y)
indices = list(range(len(y_train)))
unlabelled_indices = random.sample(indices, int(0.98*len(y_train)))
y_train[unlabelled_indices] = -1
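Since 98% of the training labels are replaced by -1, roughly 2% of the items remain labelled. A standalone sanity check of this masking step (re-running the same scikit-learn calls as above):

```python
import random
import numpy as np
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split

X, y = make_circles(n_samples=2000, noise=0.07)
X_train, X_val, y_train, y_val = train_test_split(X, y)

indices = list(range(len(y_train)))
unlabelled_indices = random.sample(indices, int(0.98 * len(y_train)))
y_train[unlabelled_indices] = -1  # -1 marks an item as unlabelled

labelled_fraction = np.mean(y_train != -1)
print(labelled_fraction)  # close to 0.02
```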
Complete case
# Getting only labelled items
labelled_indices = list(set(indices)-set(unlabelled_indices))
X_train_cc = X_train[labelled_indices]
y_train_cc = y_train[labelled_indices]
# Converting to tensors
X_train_cc = torch.from_numpy(X_train_cc).to(torch.float32)
y_train_cc = torch.from_numpy(y_train_cc).to(torch.int64)
X_val_cc = torch.from_numpy(X_val).to(torch.float32)
y_val_cc = torch.from_numpy(y_val).to(torch.int64)
plt.figure(figsize=(10, 10))
plt.title("Dataset")
plt.scatter(X_train_cc[:, 0], X_train_cc[:, 1], c=y_train_cc, cmap=cl.ListedColormap(['orange', 'blue']))
plt.show()
# Creating datasets
train_dataset = TensorDataset(X_train_cc, y_train_cc)
val_dataset = TensorDataset(X_val_cc, y_val_cc)
# Creating dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16)
val_dataloader = [DataLoader(val_dataset, batch_size=16)]
# Defining model
model = nn.Sequential(
nn.Linear(2, 30),
nn.ReLU(),
nn.Linear(30, 20),
nn.ReLU(),
nn.Linear(20, 2)
).to('cuda')
# Defining optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.04, momentum=0.9)
# Defining parameters
args = semipy.tools.get_config('config.yaml')
# Using semipy's CompleteCase class
# As we only have labelled items in the dataloaders, the complete case will be
# computed instead of real PseudoLabel
method = semipy.methods.CompleteCase(args, model, train_dataloader, val_dataloader, optimizer, scheduler=None, num_classes=2)
results = method.training()
(training progress bars: 100 epochs of 2 iterations each; output truncated)
def show_separation(model, save=False, name_to_save=""):
    sns.set(style="white")
    model = model.to('cpu')
    xx, yy = np.mgrid[-1.5:1.5:.01, -1.5:1.5:.01]
    grid = np.c_[xx.ravel(), yy.ravel()]
    batch = torch.from_numpy(grid).type(torch.float32)
    with torch.no_grad():
        probs, indices = torch.max(torch.softmax(model(batch), dim=1), dim=1)
    for i, (prob, index) in enumerate(zip(probs, indices)):
        if index.item() == 0:
            probs[i] = torch.tensor(1 - prob.item())
    probs = probs.numpy().reshape(xx.shape)
    f, ax = plt.subplots(figsize=(12, 10))
    ax.set_title("Decision boundary", fontsize=14)
    contour = ax.contourf(xx, yy, probs, 15, cmap="RdBu",
                          vmin=0, vmax=1)
    ax_c = f.colorbar(contour)
    ax_c.set_label("$P(y = 1)$")
    # ax_c.set_ticks([0, .25, .5, .75, 1])
    ax.scatter(X[100:, 0], X[100:, 1], c=y[100:], s=40,
               cmap="RdBu", vmin=-.2, vmax=1.2,
               edgecolor="white", linewidth=1)
    ax.set(xlabel="$X_1$", ylabel="$X_2$")
    if save:
        plt.savefig(name_to_save)
    else:
        plt.show()
plt.plot(range(len(results)), results)
[<matplotlib.lines.Line2D at 0x23817f74bb0>]
show_separation(model)
PseudoLabel
# Converting to tensors
X_train_t = torch.from_numpy(X_train).to(torch.float32)
y_train_t = torch.from_numpy(y_train).to(torch.int64)
X_val_t = torch.from_numpy(X_val).to(torch.float32)
y_val_t = torch.from_numpy(y_val).to(torch.int64)
plt.figure(figsize=(10, 10))
plt.title("Dataset")
plt.scatter(X_train_t[:, 0], X_train_t[:, 1], c=y_train, cmap=cl.ListedColormap(['darkgray', 'orange', 'blue']))
plt.show()
# Creating datasets
train_dataset = TensorDataset(X_train_t, y_train_t)
val_dataset = TensorDataset(X_val_t, y_val_t)
# Using custom sampler from semipy
sampler = semipy.sampler.JointSampler(train_dataset, batch_size=32, proportion=0.5)
# Defining dataloaders using the created sampler
train_dataloader = DataLoader(train_dataset, batch_sampler=sampler)
val_dataloader = [DataLoader(val_dataset, batch_size=32)]
# Defining model
model = nn.Sequential(
nn.Linear(2, 30),
nn.ReLU(),
nn.Linear(30, 20),
nn.ReLU(),
nn.Linear(20, 2)
).to('cuda')
# Defining optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.04, momentum=0.9)
# Defining parameters
args = semipy.tools.get_config('config.yaml')
# Using semipy's PseudoLabel class
method = semipy.methods.PseudoLabel(args, model, train_dataloader, val_dataloader, optimizer, scheduler=None, num_classes=2)
results = method.training()
(training progress bars: 100 epochs of 2 iterations each; output truncated)
plt.plot(range(len(results)), results)
[<matplotlib.lines.Line2D at 0x2386a977cd0>]
show_separation(model)
FixMatch
With the FixMatch method, we need to define two types of data augmentation: a weak one and a strong one. Here we define a rotation by a random angle as the weak augmentation, and Gaussian noise as the strong augmentation.
class Rotation:
    """Rotation transform on 2D points."""
    def __call__(self, coord):
        angle = random.randint(0, 180)
        phi = torch.tensor(angle * math.pi / 180)
        s = torch.sin(phi)
        c = torch.cos(phi)
        rotation_matrix = torch.stack([torch.stack([c, -s]),
                                       torch.stack([s, c])])
        coord = coord @ rotation_matrix
        return coord

class GaussianNoise:
    """Adds Gaussian noise to 2D points."""
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std
    def __call__(self, coord):
        return coord + torch.randn(coord.size()) * self.std + self.mean
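A quick property check on these two augmentations: the rotation is norm-preserving (it only moves a point around the circle, so it is label-preserving on this dataset), while the Gaussian noise perturbs the point. The sketch below re-defines both classes so it is self-contained:

```python
import math
import random
import torch

class Rotation:
    """Rotate 2D points by a random angle (weak augmentation)."""
    def __call__(self, coord):
        angle = random.randint(0, 180)
        phi = torch.tensor(angle * math.pi / 180)
        s, c = torch.sin(phi), torch.cos(phi)
        rotation_matrix = torch.stack([torch.stack([c, -s]),
                                       torch.stack([s, c])])
        return coord @ rotation_matrix

class GaussianNoise:
    """Add Gaussian noise to 2D points (strong augmentation)."""
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std
    def __call__(self, coord):
        return coord + torch.randn(coord.size()) * self.std + self.mean

point = torch.tensor([1.0, 0.0])
rotated = Rotation()(point)
noisy = GaussianNoise(std=0.01)(point)
# The rotated point keeps its distance to the origin; the noisy one only
# approximately, which is what makes it the "strong" augmentation here.
print(torch.isclose(rotated.norm(), point.norm()))
```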
# Using semipy's SSLDataset to apply the two kinds of augmentations
aug_train_dataset = semipy.datasets.SSLDataset(train_dataset, weak_transform=Rotation(), strong_transform=GaussianNoise(std=0.01))
# Using custom sampler from semipy
sampler = semipy.sampler.JointSampler(train_dataset, batch_size=32, proportion=0.5)
# Defining dataloaders using the created sampler
train_dataloader = DataLoader(aug_train_dataset, batch_sampler=sampler)
val_dataloader = [DataLoader(val_dataset, batch_size=32)]
# Defining model
model = nn.Sequential(
nn.Linear(2, 30),
nn.ReLU(),
nn.Linear(30, 20),
nn.ReLU(),
nn.Linear(20, 2)
).to('cuda')
# Defining optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.04, momentum=0.9)
# Defining parameters
args = semipy.tools.get_config('config.yaml')
# Using semipy's FixMatch class
method = semipy.methods.FixMatch(args, model, train_dataloader, val_dataloader, optimizer, scheduler=None, num_classes=2)
results = method.training()
(training progress bars: 100 epochs of 2 iterations each; output truncated)
show_separation(model)
plt.plot(range(len(results)), results)
[<matplotlib.lines.Line2D at 0x2386a99fcd0>]
Complete Case with data augmentation
# Using custom sampler from semipy
sampler = semipy.sampler.JointSampler(aug_train_dataset, batch_size=16, proportion=1)
# Defining dataloaders using the created sampler
train_dataloader = DataLoader(aug_train_dataset, batch_sampler=sampler)
val_dataloader = [DataLoader(val_dataset, batch_size=16)]
C:\Users\lboiteau\Documents\Demos\semipy\sampler\jointsampler.py:35: UserWarning: Warning : you are in the complete case. All items will be labelledand no unlabelled items will be seen by the model.
warnings.warn(("Warning : you are in the complete case. All items will be labelled"
# Defining model
model = nn.Sequential(
nn.Linear(2, 30),
nn.ReLU(),
nn.Linear(30, 20),
nn.ReLU(),
nn.Linear(20, 2)
).to('cuda')
# Defining optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.04, momentum=0.9)
# Defining parameters
args = semipy.tools.get_config('config.yaml')
# Using semipy's FixMatch class; with proportion=1 the complete case is computed
method = semipy.methods.FixMatch(args, model, train_dataloader, val_dataloader, optimizer, scheduler=None, num_classes=2)
results = method.training()
0%| | 0/100 [00:00<?, ?it/s]
plt.plot(range(len(results)), results)
[<matplotlib.lines.Line2D at 0x23890c9c550>]
show_separation(model)
JointSampler tutorial
In this tutorial, we will show how easy it is to use the JointSampler provided in SemiPy. This sampler is ideal for Semi-Supervised Learning tasks, as it creates batches containing both labelled and unlabelled items, according to user-specified parameters such as 'batch_size' and 'proportion'. The 'proportion' parameter is the proportion of labelled items to include in each batch. For example, if the user chooses a proportion of 0.3, then 30% of each batch's items will be labelled.
import semipy as smp
import numpy as np
from torch.utils.data import DataLoader
To keep the example simple, let's use synthetic tabular data to understand exactly how the JointSampler works. The only constraint for the sampler to work properly concerns the data itself: unlabelled items must carry the label -1.
Let's use a 3-class dataset of length 20, where items 7 to 20 are unlabelled (with label -1):
data = np.array([
(1, 1),
(2, 1),
(3, 1),
(4, 2),
(5, 3),
(6, 3),
(7, -1),
(8, -1),
(9, -1),
(10, -1),
(11, -1),
(12, -1),
(13, -1),
(14, -1),
(15, -1),
(16, -1),
(17, -1),
(18, -1),
(19, -1),
(20, -1)
])
Using the SSLDataset class from SemiPy to convert the numpy array into an indexable dataset:
dataset = smp.datasets.SSLDataset(data)
Now we can easily index any item in the dataset, for example:
dataset[0]
(1, 1)
Let's define our sampler. As it acts as a 'batch_sampler', this is where we choose parameters such as batch size and proportion.
sampler = smp.sampler.JointSampler(dataset=dataset, batch_size=4, proportion=0.5)
Now, to simulate some learning epochs, let's use the sampler in a torch DataLoader:
dataloader = DataLoader(dataset, batch_sampler=sampler)
We chose a batch size of 4 and a proportion of 0.5, so we expect each batch to contain 4 items: 2 labelled and 2 unlabelled. Let's see:
# We simulate 5 epochs
for epoch in range(5):
print(f'----------Epoch {epoch}----------')
for x, y in dataloader:
print(x)
----------Epoch 0----------
tensor([1, 2, 7, 8], dtype=torch.int32)
tensor([ 3, 4, 9, 10], dtype=torch.int32)
tensor([ 5, 6, 11, 12], dtype=torch.int32)
----------Epoch 1----------
tensor([ 3, 5, 13, 14], dtype=torch.int32)
tensor([ 4, 6, 15, 16], dtype=torch.int32)
tensor([ 1, 2, 17, 18], dtype=torch.int32)
----------Epoch 2----------
tensor([ 1, 3, 19, 20], dtype=torch.int32)
tensor([ 5, 2, 17, 13], dtype=torch.int32)
tensor([ 6, 4, 10, 8], dtype=torch.int32)
----------Epoch 3----------
tensor([ 5, 2, 14, 16], dtype=torch.int32)
tensor([ 6, 1, 18, 9], dtype=torch.int32)
tensor([ 3, 4, 20, 12], dtype=torch.int32)
----------Epoch 4----------
tensor([ 6, 1, 19, 11], dtype=torch.int32)
tensor([ 3, 5, 15, 7], dtype=torch.int32)
tensor([ 4, 2, 14, 11], dtype=torch.int32)
As we can see, each epoch is divided into 3 batches, and each batch correctly contains a proportion of 0.5 labelled samples. Every labelled item is also seen once per epoch. Since there are many more unlabelled samples than labelled ones, they cannot all be seen in one epoch, so the sampler remembers where it stopped sampling unlabelled items at the end of each epoch and continues with the next unseen unlabelled samples in the following epoch.
Note that when every sample of one subset (labelled or unlabelled) has been seen, that subset is shuffled (as we can see after the first epoch with labelled items, or after the first batch of Epoch 2 with unlabelled items).
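The behaviour just described (a fixed labelled count per batch, an epoch defined by the labelled set, and an unlabelled cursor that survives across epochs) can be sketched in a few lines of plain Python. This is a hypothetical, simplified reimplementation for illustration only, not semipy's actual JointSampler; it assumes 0 < proportion < 1:

```python
import random

class SimpleJointSampler:
    """Hypothetical sketch of the behaviour described above.
    Not semipy's actual JointSampler; assumes 0 < proportion < 1."""

    def __init__(self, labels, batch_size, proportion, seed=0):
        self.labelled = [i for i, y in enumerate(labels) if y != -1]
        self.unlabelled = [i for i, y in enumerate(labels) if y == -1]
        self.n_lab = round(batch_size * proportion)   # labelled items per batch
        self.n_unlab = batch_size - self.n_lab        # unlabelled items per batch
        self.rng = random.Random(seed)
        self.rng.shuffle(self.unlabelled)
        self.u_pos = 0  # cursor: remembers where unlabelled sampling stopped

    def next_unlabelled(self, n):
        out = []
        while len(out) < n:
            if self.u_pos == len(self.unlabelled):   # subset exhausted:
                self.rng.shuffle(self.unlabelled)    # reshuffle and restart
                self.u_pos = 0
            out.append(self.unlabelled[self.u_pos])
            self.u_pos += 1
        return out

    def __iter__(self):
        # an epoch ends once every labelled item has been seen
        lab = self.labelled[:]
        self.rng.shuffle(lab)
        for i in range(0, len(lab) - self.n_lab + 1, self.n_lab):
            yield lab[i:i + self.n_lab] + self.next_unlabelled(self.n_unlab)


labels = [1, 1, 1, 2, 3, 3] + [-1] * 14   # same layout as the toy dataset
sampler = SimpleJointSampler(labels, batch_size=4, proportion=0.5)
for batch in sampler:
    print(batch)  # 3 batches, each with 2 labelled and 2 unlabelled indices
```

Iterating over the sampler a second time simulates a second epoch: the labelled indices are reshuffled, while the unlabelled cursor picks up where it left off.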

Another thing to notice concerns the proportion. If we had set 'proportion=0.6', we would expect 4*0.6=2.4 labelled samples per batch, which is rounded to 2, so the sampler's behaviour would not change. But with 'proportion=0.65', we would expect 4*0.65=2.6 labelled samples per batch, which is rounded to 3. Let's check whether this is true:
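Before running the sampler, we can check the arithmetic itself. Note that this quick check uses Python's built-in round, which the actual sampler may or may not use internally:

```python
batch_size = 4
for proportion in (0.5, 0.6, 0.65):
    n_labelled = round(batch_size * proportion)
    print(f"proportion={proportion}: {n_labelled} labelled items per batch")
# proportion=0.5: 2 labelled items per batch
# proportion=0.6: 2 labelled items per batch
# proportion=0.65: 3 labelled items per batch
```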
# We first try with proportion=0.6
sampler = smp.sampler.JointSampler(dataset=dataset, batch_size=4, proportion=0.6)
dataloader = DataLoader(dataset, batch_sampler=sampler)
# We simulate 5 epochs
for epoch in range(5):
print(f'----------Epoch {epoch}----------')
for x, y in dataloader:
print(x)
----------Epoch 0----------
tensor([1, 2, 7, 8], dtype=torch.int32)
tensor([ 3, 4, 9, 10], dtype=torch.int32)
tensor([ 5, 6, 11, 12], dtype=torch.int32)
----------Epoch 1----------
tensor([ 4, 1, 13, 14], dtype=torch.int32)
tensor([ 5, 2, 15, 16], dtype=torch.int32)
tensor([ 3, 6, 17, 18], dtype=torch.int32)
----------Epoch 2----------
tensor([ 4, 5, 19, 20], dtype=torch.int32)
tensor([ 2, 6, 16, 14], dtype=torch.int32)
tensor([ 1, 3, 9, 15], dtype=torch.int32)
----------Epoch 3----------
tensor([ 5, 1, 10, 18], dtype=torch.int32)
tensor([ 3, 4, 20, 19], dtype=torch.int32)
tensor([ 2, 6, 7, 12], dtype=torch.int32)
----------Epoch 4----------
tensor([ 5, 3, 11, 17], dtype=torch.int32)
tensor([ 2, 6, 8, 13], dtype=torch.int32)
tensor([ 1, 4, 14, 15], dtype=torch.int32)
As we can see, nothing changed. But with proportion=0.65 we obtain only two batches per epoch, because each batch now contains 3 labelled items.
# Now with proportion=0.65
sampler = smp.sampler.JointSampler(dataset=dataset, batch_size=4, proportion=0.65)
dataloader = DataLoader(dataset, batch_sampler=sampler)
# We simulate 5 epochs
for epoch in range(5):
print(f'----------Epoch {epoch}----------')
for x, y in dataloader:
print(x)
----------Epoch 0----------
tensor([1, 2, 3, 7], dtype=torch.int32)
tensor([4, 5, 6, 8], dtype=torch.int32)
----------Epoch 1----------
tensor([3, 2, 5, 9], dtype=torch.int32)
tensor([ 6, 4, 1, 10], dtype=torch.int32)
----------Epoch 2----------
tensor([ 4, 2, 5, 11], dtype=torch.int32)
tensor([ 6, 1, 3, 12], dtype=torch.int32)
----------Epoch 3----------
tensor([ 6, 2, 5, 13], dtype=torch.int32)
tensor([ 1, 3, 4, 14], dtype=torch.int32)
----------Epoch 4----------
tensor([ 4, 1, 5, 15], dtype=torch.int32)
tensor([ 2, 6, 3, 16], dtype=torch.int32)
Finally, it is good to know that the sampler can also produce batches made entirely of labelled or entirely of unlabelled items. Simply adjust the proportion.
# Complete case (fully labelled)
sampler = smp.sampler.JointSampler(dataset=dataset, batch_size=4, proportion=1.0)
dataloader = DataLoader(dataset, batch_sampler=sampler)
# We simulate 5 epochs
for epoch in range(5):
print(f'----------Epoch {epoch}----------')
for x, y in dataloader:
print(x)
----------Epoch 0----------
tensor([1, 2, 3, 4], dtype=torch.int32)
tensor([5, 6], dtype=torch.int32)
----------Epoch 1----------
tensor([5, 3, 4, 2], dtype=torch.int32)
tensor([6, 1], dtype=torch.int32)
----------Epoch 2----------
tensor([2, 5, 4, 6], dtype=torch.int32)
tensor([3, 1], dtype=torch.int32)
----------Epoch 3----------
tensor([4, 5, 2, 3], dtype=torch.int32)
tensor([1, 6], dtype=torch.int32)
----------Epoch 4----------
tensor([4, 3, 6, 5], dtype=torch.int32)
tensor([2, 1], dtype=torch.int32)
C:\Users\lboiteau\Documents\Demos\semipy\sampler\jointsampler.py:35: UserWarning: Warning: you are in the complete case. All items will be labelled and no unlabelled items will be seen by the model.
# Fully unlabelled (proportion set to 0)
sampler = smp.sampler.JointSampler(dataset=dataset, batch_size=4, proportion=0.0)
dataloader = DataLoader(dataset, batch_sampler=sampler)
# We simulate 5 epochs
for epoch in range(5):
print(f'----------Epoch {epoch}----------')
for x, y in dataloader:
print(x)
----------Epoch 0----------
tensor([ 7, 8, 9, 10], dtype=torch.int32)
tensor([11, 12, 13, 14], dtype=torch.int32)
tensor([15, 16, 17, 18], dtype=torch.int32)
tensor([19, 20], dtype=torch.int32)
----------Epoch 1----------
tensor([16, 13, 10, 15], dtype=torch.int32)
tensor([11, 14, 17, 18], dtype=torch.int32)
tensor([ 9, 7, 12, 19], dtype=torch.int32)
tensor([ 8, 20], dtype=torch.int32)
----------Epoch 2----------
tensor([10, 14, 17, 20], dtype=torch.int32)
tensor([ 9, 7, 15, 16], dtype=torch.int32)
tensor([18, 8, 11, 12], dtype=torch.int32)
tensor([13, 19], dtype=torch.int32)
----------Epoch 3----------
tensor([12, 11, 13, 17], dtype=torch.int32)
tensor([ 8, 9, 10, 7], dtype=torch.int32)
tensor([15, 16, 19, 20], dtype=torch.int32)
tensor([14, 18], dtype=torch.int32)
----------Epoch 4----------
tensor([ 7, 14, 19, 18], dtype=torch.int32)
tensor([10, 8, 17, 9], dtype=torch.int32)
tensor([12, 20, 13, 16], dtype=torch.int32)
tensor([15, 11], dtype=torch.int32)
C:\Users\lboiteau\Documents\Demos\semipy\sampler\jointsampler.py:39: UserWarning: Warning: you chose a proportion of 0. All items will be unlabelled and no labelled items will be seen by the model.
Ended: Getting started
API ↵
semipy.datasets
Warning
This section is in construction.
This module allows users to get datasets for SSL. It can download and transform well-known SSL benchmark datasets, and also split a custom dataset into labelled/unlabelled sets.
Classes
| API Reference | Description |
|---|---|
| semipy.datasets.utils.SSLDataset | A class that transforms a simple dataset into an SSL dataset with weakly and strongly augmented data. |
| semipy.datasets.utils.SSLDatasetFolder | Same class as SSLDataset but dedicated to multi-GPU training with DDP. |
Functions
| API Reference | Description |
|---|---|
| semipy.datasets.get_dataset | To call the right getter function for the right dataset. |
| semipy.datasets.cifar.get_cifar | To download CIFAR with or without SSL constraints. |
| semipy.datasets.svhn.get_svhn | To download SVHN with or without SSL constraints. |
| semipy.datasets.stl10.get_stl10 | To download STL10 with or without SSL constraints. |
| semipy.datasets.medmnist.get_medmnist | To download one of MedMNIST datasets with or without SSL constraints. |
| semipy.datasets.utils.split_dataset | To split an SSLDataset object into train, test and validation sets. |
| semipy.datasets.utils.build_transforms | To build a composition of transformations from a list of transforms configurations. |
| semipy.datasets.utils.build_augmentations | To read weak and strong augmentations from yaml file. |
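The labelled/unlabelled split convention used throughout the library (unlabelled items carry the label -1) can be illustrated with a minimal numpy sketch. The function below is illustrative only; it is not semipy's actual split_dataset API:

```python
import numpy as np

def split_labels(y, n_labelled, seed=0):
    """Keep n_labelled labels and mark every other item as unlabelled (-1).
    Illustrative sketch, not semipy's actual split_dataset."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y).copy()
    keep = rng.choice(len(y), size=n_labelled, replace=False)
    mask = np.ones(len(y), dtype=bool)
    mask[keep] = False
    y[mask] = -1               # the convention SemiPy's samplers expect
    return y

y = split_labels(np.arange(10) % 3, n_labelled=4)
print(int((y == -1).sum()))    # 6 items are now unlabelled
```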
semipy.methods
Warning
This section is in construction.
This module implements different SSL methods; they differ mainly in how their losses are computed.
Classes
| API Reference | Description |
|---|---|
| semipy.methods.abstractMethod | Base class used to build SSL methods without PyTorch Lightning. |
| semipy.methods.CompleteCase | Applies complete case. |
| semipy.methods.PseudoLabel | Applies PseudoLabel method. |
| semipy.methods.FixMatch | Applies FixMatch method. |
| semipy.methods.PiModel | Applies PiModel method. |
| semipy.methods.VAT | Applies VAT method. |
| semipy.methods.AdaMatch | Applies AdaMatch method. |
| semipy.methods.utils.EMA | For using Exponential Moving Average on validation model's parameters. |
| semipy.methods.utils.DistAlign | To perform Distribution Alignment. |
| semipy.methods.utils.EarlyStopping | To monitor a specific metric in order to apply earlystopping. |
Functions
| API Reference | Description |
|---|---|
| semipy.methods.utils.get_metrics | Reads a metrics YAML file and outputs a dictionary of metrics. |
| semipy.methods.functional.pseudolabel_loss | To compute PseudoLabel loss. |
| semipy.methods.functional.fixmatch_loss | To compute FixMatch loss. |
| semipy.methods.functional.pimodel_loss | To compute PiModel loss. |
| semipy.methods.functional.vat_loss | To compute VAT loss. |
| semipy.methods.functional.adamatch_loss | To compute AdaMatch loss. |
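The idea shared by several of these losses can be illustrated with pseudo-labelling: keep only confident predictions on unlabelled data and train against their argmax. The numpy sketch below shows the principle only; it is not semipy's actual pseudolabel_loss implementation:

```python
import numpy as np

def pseudolabel_loss_sketch(logits, threshold=0.95):
    """Masked cross-entropy against the model's own argmax predictions.
    Illustration only; not semipy.methods.functional.pseudolabel_loss."""
    z = logits - logits.max(axis=1, keepdims=True)      # stable softmax
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    confidence = probs.max(axis=1)
    pseudo = probs.argmax(axis=1)                       # hard pseudo-labels
    mask = confidence >= threshold                      # confident items only
    ce = -np.log(probs[np.arange(len(probs)), pseudo])
    return float((ce * mask).mean())

confident = np.array([[8.0, 0.0], [0.0, 8.0]])
uncertain = np.array([[0.1, 0.0], [0.0, 0.1]])
print(pseudolabel_loss_sketch(confident))   # small positive loss
print(pseudolabel_loss_sketch(uncertain))   # 0.0: nothing passes the threshold
```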
semipy.pl
Warning
This section is in construction.
This module implements the different methods listed in semipy.methods but with PyTorch Lightning support.
| API Reference | Description |
|---|---|
| semipy.pl.LitCompleteCase | To use the complete case with PyTorch Lightning. |
| semipy.pl.LitPseudoLabel | To use PseudoLabel method with PyTorch Lightning. |
| semipy.pl.LitFixMatch | To use FixMatch method with PyTorch Lightning. |
| semipy.pl.LitPiModel | To use PiModel method with PyTorch Lightning. |
| semipy.pl.LitVAT | To use VAT method with PyTorch Lightning. |
| semipy.pl.LitAdaMatch | To use AdaMatch method with PyTorch Lightning. |
semipy.models
Warning
This section is in construction.
This module provides an easy tool to get models from PyTorch, but also adds new WideResNets not included in PyTorch.
Classes
| API Reference | Description |
|---|---|
| semipy.models.WideResNet | To create specific Wide ResNets that are not present in PyTorch. |
Functions
| API Reference | Description |
|---|---|
| semipy.models.get_model | To retrieve any pre-trained (or not) model from PyTorch. |
semipy.sampler
Warning
This section is in construction.
This module implements a useful sampler for SSL. It creates batches made of both labelled and unlabelled samples, according to parameters such as batch size and labelled proportion.
| API Reference | Description |
|---|---|
| semipy.sampler.JointSampler | SSL sampler to create batches made of both labelled and unlabelled items. |
| semipy.sampler.DistributedJointSampler | SSL sampler based on JointSampler adapted for DDP multi-GPU. |
semipy.transforms
Warning
This section is in construction.
This module is dedicated to data transformation. Developers can add custom transformations here in the form of new classes. As an example, a 'Cutout' class has been created, as it is often used in SSL benchmarks but not present in torchvision.transforms.
| API Reference | Description |
|---|---|
| semipy.transforms.Cutout | To apply Cutout on images for data augmentation. |
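The idea behind Cutout (masking a random square patch of the image) can be sketched on a numpy array. This is an illustrative reimplementation, not semipy's actual Cutout class:

```python
import numpy as np

def cutout(img, size, rng=None):
    """Zero out a random size x size patch (clipped at the borders).
    Illustrative sketch, not semipy.transforms.Cutout."""
    rng = rng or np.random.default_rng()
    h, w = img.shape[:2]
    cy, cx = int(rng.integers(h)), int(rng.integers(w))   # patch centre
    y0, y1 = max(0, cy - size // 2), min(h, cy + size // 2)
    x0, x1 = max(0, cx - size // 2), min(w, cx + size // 2)
    out = img.copy()
    out[y0:y1, x0:x1] = 0
    return out

img = np.ones((8, 8))
augmented = cutout(img, size=4, rng=np.random.default_rng(0))
print(int((augmented == 0).sum()))  # number of masked pixels
```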
semipy.tools
Warning
This section is in construction.
This module provides useful tools for the proper functioning of the library.
Classes
| API Reference | Description |
|---|---|
| semipy.tools.SSLTrainer | Used to create a Train/Validation/Test 'environment' without PyTorch Lightning. |
Functions
| API Reference | Description |
|---|---|
| semipy.tools.get_config | To read parameters from a YAML file. |
| semipy.tools.get_optimizer | To get the right optimizer with the correct parameters from torch.optim. |
| semipy.tools.get_cosine_schedule_with_warmup | To get a cosine scheduler. |
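As an illustration of what such a schedule computes, here is a minimal sketch of a cosine learning-rate multiplier with linear warmup, the formula commonly used in SSL training loops. The actual signature of get_cosine_schedule_with_warmup may differ:

```python
import math

def cosine_lr(step, total_steps, warmup_steps):
    """Learning-rate multiplier: linear warmup, then cosine decay to 0.
    Illustrative sketch; semipy's actual helper may differ."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)                 # linear warmup 0 -> 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))      # cosine decay 1 -> 0

print(cosine_lr(0, 100, 10))              # 0.0 (start of warmup)
print(cosine_lr(10, 100, 10))             # 1.0 (warmup done, peak)
print(round(cosine_lr(100, 100, 10), 6))  # 0.0 (end of training)
```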